import os
import shutil
import tempfile
from typing import Dict, List, Optional, Union, Type

import torch
from PIL import Image
import wandb
import numpy as np
from tqdm import tqdm

from noiser.ReSD import ReSDPipeline
from noiser.sr_noise.SRAttacker import SRAttacker
# Import the watermark attackers from wmattacker.py
from wmattacker import (
    WMAttacker,
    GaussianNoiseAttacker,
    GaussianBlurAttacker,
    JPEGAttacker,
    BrightnessAttacker,
    ContrastAttacker,
    RotateAttacker,
    BM3DAttacker,
    ScaleAttacker,
    CropAttacker,
    VAEWMAttacker,
    DiffWMAttacker
)



class WatermarkAttackFramework:
    """
    Framework for running watermark attacks on images.

    Single attacks write each attacked image to ``<output_dir>/<attacker_name>/``;
    combined (chained) attacks write to ``<output_dir>/combined_<case_name>/``.
    Every output keeps the watermarked image's file name. Images whose output file
    already exists are skipped, so re-running resumes an interrupted run.
    """

    def __init__(self,
                 watermarked_dir: str,
                 output_dir: str,
                 original_dir: Optional[str] = None,
                 single_attackers: Optional[Dict[str, WMAttacker]] = None,
                 combined_cases: Optional[Dict[str, Dict[str, WMAttacker]]] = None,
                 metrics_evaluator: Optional[object] = None,
                 device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                 debug: bool = False):
        """
        Initialize the watermark attack framework.

        Args:
            watermarked_dir: Directory containing watermarked images
            output_dir: Base directory for outputs (created if missing)
            original_dir: Directory with original (non-watermarked) images for evaluation
            single_attackers: Dictionary of single attacker objects. ``None`` builds the
                default set (which loads several large models); an empty dict means
                "no single attacks".
            combined_cases: Dictionary of combined attack scenarios. Each scenario is an
                ordered mapping of step name -> attacker. ``None`` builds the default
                cases, which look up keys such as 'high_RealESRGAN_x4' in
                ``single_attackers``; an empty dict means "no combined attacks".
            metrics_evaluator: MetricsEvaluator object for evaluation (``None`` disables it)
            device: Device to use for models. The default is evaluated once, when this
                module is imported.
            debug: If True, only the first 5 images (in sorted file-name order) are used.
        """
        self.watermarked_dir = watermarked_dir
        self.output_dir = output_dir
        self.original_dir = original_dir
        self.device = device
        self.metrics_evaluator = metrics_evaluator
        self.debug = debug

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Use ``is None`` rather than ``or``: an explicitly empty dict should disable
        # that kind of attack, not silently load every default model.
        # single_attackers must be assigned first because the default combined
        # cases reference entries of self.single_attackers.
        self.single_attackers = (self._default_single_attackers()
                                 if single_attackers is None else single_attackers)
        self.combined_cases = (self._default_combined_cases()
                               if combined_cases is None else combined_cases)
        # SR model families used by the default attackers. Not read anywhere in this
        # class; presumably consumed by external callers.
        self.sr_type = ['SwinIR_x4', 'RealESRGAN_x4', 'SRResize_x4', 'AdcSR']

    def _default_single_attackers(self) -> Dict[str, WMAttacker]:
        """
        Create default set of single attackers.

        This eagerly loads a fp16 Stable Diffusion 2.1 pipeline (shared by all
        diffusion attackers), several learned-compression models and several
        super-resolution models onto ``self.device``. Dict order is the order in
        which ``run_single_attacks`` runs them.

        Returns:
            Dictionary of default attackers
        """
        att_pipe = ReSDPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16,
                                                revision="fp16")
        att_pipe.set_progress_bar_config(disable=True)
        att_pipe.to(self.device)
        attackers = {
            'diff_attacker_60': DiffWMAttacker(att_pipe, batch_size=5, noise_step=60, captions={}),
            'cheng2020-anchor_1': VAEWMAttacker('cheng2020-anchor', quality=1, metric='mse', device=self.device),
            'bmshj2018-factorized_1': VAEWMAttacker('bmshj2018-factorized', quality=1, metric='mse', device=self.device),
            'diff_attacker_20': DiffWMAttacker(att_pipe, batch_size=5, noise_step=20, captions={}),
            'cheng2020-anchor_3': VAEWMAttacker('cheng2020-anchor', quality=3, metric='mse', device=self.device),
            'bmshj2018-factorized_3': VAEWMAttacker('bmshj2018-factorized', quality=3, metric='mse',
                                                    device=self.device),
            'diff_attacker_10': DiffWMAttacker(att_pipe, batch_size=5, noise_step=10, captions={}),
            'cheng2020-anchor_5': VAEWMAttacker('cheng2020-anchor', quality=5, metric='mse', device=self.device),
            'bmshj2018-factorized_5': VAEWMAttacker('bmshj2018-factorized', quality=5, metric='mse',
                                                    device=self.device),
            'jpeg_attacker_10': JPEGAttacker(quality=10),
            'jpeg_attacker_5': JPEGAttacker(quality=5),
            'jpeg_attacker_50': JPEGAttacker(quality=50),
            'jpeg_attacker_80': JPEGAttacker(quality=80),
            'rotate_45': RotateAttacker(degree=45),
            'rotate_15': RotateAttacker(degree=15),
            'rotate_5': RotateAttacker(degree=5),
            'brightness_0.1': BrightnessAttacker(brightness=0.1),
            'contrast_0.1': ContrastAttacker(contrast=0.1),
            'Gaussian_noise_1.0': GaussianNoiseAttacker(std=1.0),
            'Gaussian_noise_0.5': GaussianNoiseAttacker(std=0.5),
            'Gaussian_noise_0.1': GaussianNoiseAttacker(std=0.1),
            'Gaussian_noise_0.05': GaussianNoiseAttacker(std=0.05),
            'Gaussian_blur': GaussianBlurAttacker(kernel_size=5, sigma=1),
            'bm3d': BM3DAttacker(0.9),
            # 'crop_0.5': CropAttacker(0.5),
            # 'gaussian_noise_0.05': GaussianNoiseAttacker(std=0.05),
            # 'gaussian_noise_0.1': GaussianNoiseAttacker(std=0.1),
            # 'gaussian_blur': GaussianBlurAttacker(kernel_size=5, sigma=1),
            # 'jpeg_quality_50': JPEGAttacker(quality=50),
            # 'jpeg_quality_80': JPEGAttacker(quality=80),
            # 'brightness_0.5': BrightnessAttacker(brightness=0.5),
            # 'brightness_1.5': BrightnessAttacker(brightness=1.5),
            # 'contrast_0.5': ContrastAttacker(contrast=0.5),
            # 'contrast_1.5': ContrastAttacker(contrast=1.5),
            # 'rotate_90': RotateAttacker(degree=90),
            # 'rotate_45': RotateAttacker(degree=45),
            # 'bm3d': BM3DAttacker(),
        }
        # Super-resolution attackers (always constructed; there is no fallback if a
        # model fails to load). input_size/target_size are forwarded to SRAttacker.
        # Presumably the image is resized to input_size before SR and the result to
        # target_size afterwards; TODO confirm in SRAttacker.
        attackers.update({
            'high_SwinIR_x4': SRAttacker(model_type='SwinIR', scale=4, task='classical_sr',
                                         device=self.device,
                                         input_size=(256, 256), target_size=(256, 256)),
            'low_SwinIR_x4': SRAttacker(model_type='SwinIR', scale=4, task='classical_sr',
                                        device=self.device,
                                        input_size=(64, 64), target_size=(256, 256)),
            'high_RealESRGAN_x4': SRAttacker(model_type='RealESRGAN', model_name="RealESRGAN_x4plus",
                                             device=self.device,
                                             input_size=(256, 256), target_size=(256, 256)),
            'high_RealESRGAN_x4_no_resize': SRAttacker(model_type='RealESRGAN',
                                                       model_name="RealESRGAN_x4plus",
                                                       device=self.device,
                                                       input_size=(256, 256),
                                                       target_size=(1024, 1024)),
            'low_RealESRGAN_x4': SRAttacker(model_type='RealESRGAN', model_name="RealESRGAN_x4plus",
                                            device=self.device,
                                            input_size=(64, 64), target_size=(256, 256)),
            'low_RealESRGAN_x2_input_128': SRAttacker(model_type='RealESRGAN',
                                                      model_name="RealESRGAN_x2plus",
                                                      model_path="./weight/RealESRGAN/RealESRGAN_x2plus.pth",
                                                      device=self.device,
                                                      input_size=(128, 128),
                                                      target_size=(256, 256)),
            'low_RealESRGAN_x2_input_64': SRAttacker(model_type='RealESRGAN',
                                                     model_name="RealESRGAN_x2plus",
                                                     model_path="./weight/RealESRGAN/RealESRGAN_x2plus.pth",
                                                     device=self.device,
                                                     input_size=(64, 64),
                                                     target_size=(256, 256)),
            'high_SRResize_x4': SRAttacker(model_type='SRResize', scale_factor=4, device=self.device,
                                           input_size=(256, 256), target_size=(256, 256)),
            'low_SRResize_x4': SRAttacker(model_type='SRResize', scale_factor=4, device=self.device,
                                          input_size=(64, 64), target_size=(256, 256)),
            'high_AdcSR': SRAttacker(model_type='AdcSR', device=self.device,
                                     input_size=(256, 256), target_size=(256, 256)),
            'low_AdcSR': SRAttacker(model_type='AdcSR', device=self.device,
                                    input_size=(64, 64), target_size=(256, 256)),
        })

        return attackers

    def _default_combined_cases(self) -> Dict[str, Dict[str, WMAttacker]]:
        """
        Create default set of combined attack cases.

        Each case maps step names to attackers, applied in dict order. Cases reuse
        instances from ``self.single_attackers``, so that attribute must already be
        set. A KeyError is raised if it lacks a referenced key (e.g. when custom
        single attackers are passed without custom combined cases).

        Returns:
            Dictionary of default combined attack cases
        """
        cases = {
            # 'noise_then_jpeg': {
            #     'gaussian_noise_0.05': GaussianNoiseAttacker(std=0.05),
            #     'jpeg_quality_50': JPEGAttacker(quality=50)
            # },
            # 'rotate_then_noise': {
            #     'rotate_45': RotateAttacker(degree=45),
            #     'gaussian_noise_0.05': GaussianNoiseAttacker(std=0.05)
            # },
            # 'multiple_distortions': {
            #     'brightness_0.1': BrightnessAttacker(brightness=0.1),
            #     'contrast_0.1': ContrastAttacker(contrast=0.1),
            #     'gaussian_blur': GaussianBlurAttacker(kernel_size=3, sigma=0.5),
            #     'jpeg_quality_10': JPEGAttacker(quality=10)
            # },
            # '2_low_realesrgan':{
            #     'low_RealESRGAN_x4_1': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_2': self.single_attackers['low_RealESRGAN_x4']
            # },
            # '3_low_realesrgan': {
            #     'low_RealESRGAN_x4_1': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_2': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_3': self.single_attackers['low_RealESRGAN_x4']
            # },
            # '4_low_realesrgan': {
            #     'low_RealESRGAN_x4_1': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_2': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_3': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # '5_low_realesrgan': {
            #     'low_RealESRGAN_x4_1': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_2': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_3': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_4': self.single_attackers['low_RealESRGAN_x4'],
            #     'low_RealESRGAN_x4_5': self.single_attackers['low_RealESRGAN_x4']
            # },
            # '2_low_adcsr': {
            #     'low_AdcSR_x4_1': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_2': self.single_attackers['low_AdcSR']
            # },
            # '3_low_adcsr': {
            #     'low_AdcSR_x4_1': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_2': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_3': self.single_attackers['low_AdcSR']
            # },
            # '4_low_adcsr': {
            #     'low_AdcSR_x4_1': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_2': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_3': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_4': self.single_attackers['low_AdcSR']
            # },
            # '5_low_adcsr': {
            #     'low_AdcSR_x4_1': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_2': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_3': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_4': self.single_attackers['low_AdcSR'],
            #     'low_AdcSR_x4_5': self.single_attackers['low_AdcSR']
            # },
            # '2_high_adcsr': {
            #     'high_AdcSR_x4_1': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_2': self.single_attackers['high_AdcSR']
            # },
            # '3_high_adcsr': {
            #     'high_AdcSR_x4_1': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_2': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_3': self.single_attackers['high_AdcSR']
            # },
            # '4_high_adcsr': {
            #     'high_AdcSR_x4_1': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_2': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_3': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_4': self.single_attackers['high_AdcSR']
            # },
            # '5_high_adcsr': {
            #     'high_AdcSR_x4_1': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_2': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_3': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_4': self.single_attackers['high_AdcSR'],
            #     'high_AdcSR_x4_5': self.single_attackers['high_AdcSR']
            # },
            # '2_high_realesrgan': {
            #     'high_RealESRGAN_x4_1': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_2': self.single_attackers['high_RealESRGAN_x4']
            # },
            # '3_high_realesrgan': {
            #     'high_RealESRGAN_x4_1': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_2': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_3': self.single_attackers['high_RealESRGAN_x4']
            # },
            # '4_high_realesrgan': {
            #     'high_RealESRGAN_x4_1': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_2': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_3': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # '5_high_realesrgan': {
            #     'high_RealESRGAN_x4_1': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_2': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_3': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_4': self.single_attackers['high_RealESRGAN_x4'],
            #     'high_RealESRGAN_x4_5': self.single_attackers['high_RealESRGAN_x4']
            # }
        }
        cases.update({
            # 'noise_then_high_sr': {
            #     'gaussian_noise_0.03': GaussianNoiseAttacker(std=0.03),
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'noise_005_then_high_sr': {
            #     'gaussian_noise_0.05': GaussianNoiseAttacker(std=0.05),
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'blur_then_high_sr': {
            #     'gaussian_blur': GaussianBlurAttacker(kernel_size=5, sigma=1),
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'jpeg_50_then_high_sr': {
            #     'jpeg_attacker_50': JPEGAttacker(quality=50),
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'jpeg_then_high_sr': {
            #     'jpeg_attacker_5': JPEGAttacker(quality=5),
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'rotate_then_high_sr':{
            #     'rotate_45': RotateAttacker(degree=45),
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'rotate_15_then_high_sr': {
            #     'rotate_15': RotateAttacker(degree=15),
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'diffu_then_high_sr': {
            #     'diff_attacker_60': self.single_attackers['diff_attacker_60'],
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'diffu_20_then_high_sr': {
            #     'diff_attacker_20': self.single_attackers['diff_attacker_20'],
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            # 'bm3d_then_high_sr': {
            #     'bm3d': self.single_attackers['bm3d'],
            #     'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            # },
            'bmshj_then_high_sr': {
                'bmshj2018-factorized_5': self.single_attackers['bmshj2018-factorized_5'],
                'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            },
            'cheng_then_high_sr': {
                'cheng2020-anchor_5': self.single_attackers['cheng2020-anchor_5'],
                'high_RealESRGAN_x4': self.single_attackers['high_RealESRGAN_x4']
            },
            # 'noise_then_low_sr': {
            #     'gaussian_noise_0.03': GaussianNoiseAttacker(std=0.03),
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'noise_005_then_low_sr': {
            #     'gaussian_noise_0.05': GaussianNoiseAttacker(std=0.05),
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'blur_then_low_sr': {
            #     'gaussian_blur': GaussianBlurAttacker(kernel_size=5, sigma=1),
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'jpeg_then_low_sr': {
            #     'jpeg_attacker_5': JPEGAttacker(quality=5),
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'jpeg_50_then_low_sr': {
            #     'jpeg_attacker_50': JPEGAttacker(quality=50),
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'rotate_then_low_sr':{
            #     'rotate_45': RotateAttacker(degree=45),
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'rotate_15_then_low_sr': {
            #     'rotate_15': RotateAttacker(degree=15),
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'diffu_then_low_sr': {
            #     'diff_attacker_60': self.single_attackers['diff_attacker_60'],
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'diffu_20_then_low_sr': {
            #     'diff_attacker_20': self.single_attackers['diff_attacker_20'],
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # },
            # 'bm3d_then_low_sr': {
            #     'bm3d': self.single_attackers['bm3d'],
            #     'low_RealESRGAN_x4': self.single_attackers['low_RealESRGAN_x4']
            # }
        })
        # cases.update({
        #     'noise_then_high_sr': {
        #         'gaussian_noise_0.03': GaussianNoiseAttacker(std=0.03),
        #         'SwinIR_x4': self.single_attackers['high_SwinIR_x4']
        #     },
        #     'sr_then_noise': {
        #         'SwinIR_x4': self.single_attackers['high_SwinIR_x4'],
        #         'gaussian_noise_0.03': GaussianNoiseAttacker(std=0.03)
        #     }
        # })

        return cases

    def get_image_paths(self) -> List[str]:
        """
        Get paths of watermarked images in the watermarked directory.

        File names are sorted because ``os.listdir`` order is arbitrary. Sorting
        makes the processing order, the ``debug`` subset and the wandb sample
        indices reproducible across runs.

        Returns:
            List of watermarked image file paths
        """
        image_files = sorted(f for f in os.listdir(self.watermarked_dir)
                             if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')))
        if self.debug:
            # Debug runs only touch a small, fixed subset of the images.
            image_files = image_files[:5]
        return [os.path.join(self.watermarked_dir, name) for name in image_files]

    def skip_attack(self, out_paths: List[str], watermarked_paths: List[str]):
        """
        Drop the images whose output file already exists.

        Args:
            out_paths: Output paths, aligned index-by-index with ``watermarked_paths``
            watermarked_paths: Source image paths

        Returns:
            ``(pending_watermarked_paths, pending_out_paths)``: two lists that are
            still aligned with each other and contain only the pairs whose output
            file is missing.
        """
        # Check existence once per pair so both returned lists stay aligned even
        # if files appear on disk while this runs.
        pending = [(wm_path, out_path)
                   for wm_path, out_path in zip(watermarked_paths, out_paths)
                   if not os.path.exists(out_path)]
        return [wm for wm, _ in pending], [out for _, out in pending]

    def run_single_attacks(self, attackers: Optional[Dict[str, WMAttacker]] = None):
        """
        Run single attacks on watermarked images.

        Each attacker writes to ``<output_dir>/<attacker_name>/``. Images that are
        already present there are skipped. Evaluation runs only for attackers that
        processed at least one new image.

        Args:
            attackers: Dictionary of attacker objects (uses self.single_attackers if None)
        """
        attackers = self.single_attackers if attackers is None else attackers
        watermarked_paths = self.get_image_paths()

        for attacker_name, attacker in attackers.items():
            print(f"Running single attack: {attacker_name}")

            # Create output directory for this attack
            attack_out_dir = os.path.join(self.output_dir, attacker_name)
            os.makedirs(attack_out_dir, exist_ok=True)

            # Output keeps the watermarked image's file name.
            out_paths = [os.path.join(attack_out_dir, os.path.basename(img_path))
                         for img_path in watermarked_paths]
            filtered_watermarked_paths, filtered_out_paths = self.skip_attack(out_paths, watermarked_paths)

            if not filtered_watermarked_paths:
                print(f"  All images already processed for attacker: {attacker_name}")
                continue
            # Run the attack
            attacker.attack(filtered_watermarked_paths, filtered_out_paths)

            # Evaluate (if an evaluator is configured) over the full image set.
            self._evaluate_and_log(attacker_name, attack_out_dir, watermarked_paths, out_paths)

    def run_combined_attacks(self, combined_cases: Optional[Dict[str, Dict[str, WMAttacker]]] = None):
        """
        Run combined attack scenarios on watermarked images.

        Each case applies its attackers in sequence and writes to
        ``<output_dir>/combined_<case_name>/``. Only images without an existing
        output are processed. Results are published to the output directory only
        after the whole chain succeeds, so an interrupted run never leaves
        half-attacked images that a later run would skip as done.

        Args:
            combined_cases: Dictionary of combined attack scenarios (uses self.combined_cases if None)
        """
        combined_cases = self.combined_cases if combined_cases is None else combined_cases
        watermarked_paths = self.get_image_paths()

        for case_name, attackers in combined_cases.items():
            print(f"Running combined attack case: {case_name}")
            # Create output directory for this combined attack
            case_out_dir = os.path.join(self.output_dir, f"combined_{case_name}")
            os.makedirs(case_out_dir, exist_ok=True)

            # Create output paths
            out_paths = [os.path.join(case_out_dir, os.path.basename(img_path))
                         for img_path in watermarked_paths]
            filtered_watermarked_paths, filtered_out_paths = self.skip_attack(out_paths, watermarked_paths)

            if not filtered_watermarked_paths:
                print(f"  All images already processed for combined attack case: {case_name}")
                continue

            # Only the pending images go through the chain. Re-attacking outputs
            # that already exist would compound the attacks on every rerun.
            self._run_attack_chain(attackers, filtered_watermarked_paths, filtered_out_paths)

            # Evaluate (if an evaluator is configured) over the full image set.
            self._evaluate_and_log(f"combined_{case_name}", case_out_dir, watermarked_paths, out_paths)

    def _run_attack_chain(self,
                          attackers: Dict[str, WMAttacker],
                          src_paths: List[str],
                          dst_paths: List[str]):
        """
        Apply ``attackers`` in order to copies of ``src_paths`` and move the results to ``dst_paths``.

        The work happens in a temporary staging directory inside ``output_dir``
        (same filesystem, so ``os.replace`` is a rename). The staging directory is
        removed whether or not the chain succeeds.
        """
        stage_dir = tempfile.mkdtemp(prefix=".staging_", dir=self.output_dir)
        try:
            # Basenames are unique because every source comes from one directory listing.
            staged_paths = [os.path.join(stage_dir, os.path.basename(p)) for p in dst_paths]
            # Byte-for-byte copy: re-saving through PIL would re-encode JPEGs and
            # add an unintended extra lossy step before the chain.
            for src_path, staged_path in zip(src_paths, staged_paths):
                shutil.copyfile(src_path, staged_path)

            # Apply each attacker in sequence, in place on the staged copies.
            for i, (attacker_name, attacker) in enumerate(attackers.items()):
                print(f"  Step {i + 1}: Applying {attacker_name}")
                attacker.attack(staged_paths, staged_paths)

            for staged_path, dst_path in zip(staged_paths, dst_paths):
                os.replace(staged_path, dst_path)
        finally:
            shutil.rmtree(stage_dir, ignore_errors=True)

    def _evaluate_and_log(self,
                          attack_name: str,
                          attack_out_dir: str,
                          watermarked_paths: List[str],
                          out_paths: List[str]):
        """
        Evaluate attack results and log metrics and images.

        Does nothing without a metrics evaluator. Metrics and comparison images are
        sent to wandb only when a wandb run is active. A side-by-side comparison is
        logged for every 100th image whose attacked file exists. A failure while
        building one comparison is printed and does not abort the loop.

        Args:
            attack_name: Name of the attack (used as the wandb key prefix)
            attack_out_dir: Directory with attack outputs
            watermarked_paths: Paths to watermarked images
            out_paths: Paths to attacked images, aligned with ``watermarked_paths``
        """
        if self.metrics_evaluator is None:
            return

        # Set directories for the evaluator
        self.metrics_evaluator.set_directories(
            original_dir=self.original_dir,
            watermarked_dir=self.watermarked_dir,
            attacked_dir=attack_out_dir
        )

        # Evaluate metrics
        metrics = self.metrics_evaluator.evaluate_all_metrics()

        # wandb is imported at module level, so only an active run needs checking.
        if wandb.run is None:
            return
        wandb.log({f"{attack_name}/{k}": v for k, v in metrics.items()})

        # Log sample images to wandb every 100 images
        for i, (wm_path, att_path) in enumerate(zip(watermarked_paths, out_paths)):
            if i % 100 != 0 or not os.path.exists(att_path):
                continue
            try:
                # Close the files promptly. Convert to RGB so np.hstack works even
                # when the modes differ (e.g. RGBA PNG input vs RGB JPEG output).
                with Image.open(wm_path) as wm_img, Image.open(att_path) as att_img:
                    wm_rgb = wm_img.convert("RGB")
                    att_rgb = att_img.convert("RGB")

                # Ensure images are the same size for horizontal stacking
                if wm_rgb.size != att_rgb.size:
                    att_rgb = att_rgb.resize(wm_rgb.size)

                wandb.log({
                    f"{attack_name}/sample_{i}": wandb.Image(
                        np.hstack([np.asarray(wm_rgb), np.asarray(att_rgb)]),
                        caption="Left: Watermarked, Right: Attacked"
                    )
                })
            except Exception as e:
                # Best-effort visualisation: never let one sample abort the run.
                print(f"Error creating comparison image: {e}")

    def run_all_attacks(self):
        """Run all single attacks, then all combined attacks."""
        self.run_single_attacks()
        self.run_combined_attacks()

